CAS Logo Open main paper 🔗

5  Partitioning

Objectives

In order to expose systemic disparities, section 6 of the main paper proposes to partition policyholders according to a relevant local fairness metric. In this section, we aim to partition policyholders by proxy vulnerability (section 6.1 of the main paper) and by commercial loading (section 6.2 of the main paper).

Furthermore, we perform an experiment on Scenario 1 to see how effective partitioning is at identifying policyholders with high proxy vulnerability. Through this repeated partitioning experiment, we gain insight into best practices for partitioning.

This section is based on simulations of sec-simul-dataset and estimated metrics from sec-local, and it will serve to build the integrated framework of sec-integrated.

Packages for this section
library(tidyverse)
library(latex2exp)
library(jsonlite)
Data for this section
# Load pre-computed prediction summaries written by earlier sections.
pregroup_grid_stats <- jsonlite::fromJSON('preds/pregroup_grid_stats.json')  # estimated metrics on the (x1, x2, d) evaluation grid
pregroup_pop_stats <- jsonlite::fromJSON('preds/pregroup_pop_stats.json')  # per-policyholder metrics, nested by scenario then set (train/valid/...)
preds_sims_stats <- jsonlite::fromJSON('preds/preds_sims_stats.json')  # per-simulation statistics (used later in this section)
Functions for this section
#### EVTREE ESSENTIALS 
# Train one evtree and one rpart model on `train_data` (formula: resp ~ X1 + X2 + D),
# prune the rpart in two ways (CV-optimal CP and ~8 leaves), then score all three
# fits on `valid_data` with a BIC-style criterion. Returns a nested results list,
# or NULL (with a logged error) if anything fails.
#
# Arguments:
#   params        - list of tuning values: minbucket, maxdepth, ntrees, alpha
#   train_data    - data.frame holding resp, X1, X2, D, used for fitting
#   valid_data    - data.frame used only for the validation scores
#   response_name - label used in log messages only
#
# NOTE(review): depends on evtree, rpart and futile.logger (flog.*) being
# attached by the caller -- they are not loaded in this section's preamble.
evaluate_model <- function(params, train_data, valid_data, response_name) {
  tryCatch({
    flog.info("Training model for response: %s with params: %s", response_name, toString(params))
    
    beg_evtree <- Sys.time()
    # Train evtree model (evolutionary, globally optimized tree)
    evtree_model <- evtree(
      resp ~ X1 + X2 + D,
      data = train_data,
      control = evtree.control(
        minsplit = params$minbucket * 2 + 1,  # smallest split consistent with minbucket
        minbucket = params$minbucket,
        maxdepth = params$maxdepth,
        ntrees = params$ntrees,
        alpha = params$alpha
      ))
    time_evtree <- as.numeric(difftime(Sys.time(), beg_evtree, units = "secs"))
    
    beg_rpart <- Sys.time()
    # Train rpart (deep): near-zero cp grows a large tree; it is pruned below
    rpart_model <- rpart(
      resp ~ X1 + X2 + D,
      data = train_data,
      method = 'anova',
      control = rpart.control(
        minsplit = params$minbucket * 2 + 1,
        minbucket = params$minbucket,
        maxdepth = params$maxdepth,
        cp = 0.00001
      )
    )
    
    time_rpart <- as.numeric(difftime(Sys.time(), beg_rpart, units = "secs"))
    
    # Number of leaves actually used by the evtree (reported in the results)
    evtree_leaves <- length(unique(predict(evtree_model, type = "node")))
    rpart_cptable <- rpart_model$cptable
    
    # Prune at the CP with the smallest cross-validated error.
    # NOTE(review): despite the nearby naming, this does NOT match the evtree
    # leaf count -- `evtree_leaves` is only returned, never used for pruning.
    min_xerror_index <- which.min(rpart_cptable[, "xerror"])
    matching_cp <- rpart_cptable[min_xerror_index, "CP"]
    pruned_rpart_model <- prune(rpart_model, cp = matching_cp)
    
    # Prune a second copy to roughly 8 leaves (cptable row with nsplit closest to 7)
    target_8_splits <- 8 - 1
    matching_8_cp <- rpart_cptable[which.min(abs(rpart_cptable[, "nsplit"] - target_8_splits)), 'CP']
    pruned_8_rpart_model <- prune(rpart_model, cp = matching_8_cp)
    
    # Evaluate evtree on validation set.
    # Score is BIC-like: n * log(SSE / n) + (#distinct leaf predictions) * log(n);
    # smaller is better. (The *_mse names are historical -- this is not a plain MSE.)
    n <- nrow(valid_data)
    evtree_preds <- predict(evtree_model, newdata = valid_data)
    evtree_mse <- n * log(sum((valid_data$resp - evtree_preds)^2, na.rm = TRUE)/ n) + length(unique(evtree_preds)) * log(n)
    flog.info("Validation MSE for evtree (response: %s): %f", response_name, evtree_mse)
    
    # Evaluate CV-pruned rpart on validation set
    rpart_preds <- predict(pruned_rpart_model, newdata = valid_data, type = "vector")
    rpart_mse <- n * log(sum((valid_data$resp - rpart_preds)^2, na.rm = TRUE)/ n) + length(unique(rpart_preds)) * log(n)
    flog.info("Validation MSE for rpart (response: %s): %f", response_name, rpart_mse)
    
    # Evaluate the 8-leaf pruned rpart on validation set
    rpart_8preds <- predict(pruned_8_rpart_model, newdata = valid_data, type = "vector")
    rpart_8mse <-  n * log(sum((valid_data$resp - rpart_8preds)^2, na.rm = TRUE)/ n) + length(unique(rpart_8preds)) * log(n)
    flog.info("Validation MSE for 8rpart (response: %s): %f", response_name, rpart_8mse)
    
    # Return results for both models. rpart leaf counts are inferred from the
    # number of distinct (rounded) validation predictions.
    list(
      evtree = list(model = evtree_model, val_mse = evtree_mse,
                    leaves = evtree_leaves,
                    evtree_val_preds = evtree_preds,
                    time = time_evtree),
      rpart = list(model = pruned_rpart_model,
                   val_mse = rpart_mse,
                   leaves = length(unique(round(rpart_preds, 3))),
                   time = time_rpart,
                   rpart_val_preds = rpart_preds,
                   cp = matching_cp),
      rpart_eight = list(model = pruned_8_rpart_model,
                        val_mse = rpart_8mse,
                        leaves = length(unique(round(rpart_8preds, 3))),
                        time = time_rpart,
                        rpart_val_preds = rpart_8preds,
                        cp = matching_8_cp),
      response_name = response_name,
      params = params
    )
  }, error = function(e) {
    # Log and swallow: callers receive NULL for a failed configuration
    flog.error("Error while training model: %s", e$message)
    NULL
  })
}

# Collapse the middle levels of an ordered set of numeric levels into a single
# group, keeping the first `n_bottom` and last `n_top` levels unchanged.
#
# Arguments:
#   levels        - numeric (or coercible) vector of level values, in display order
#   numeric       - if TRUE, label the merged middle group with the (rounded)
#                   mean of the observations in `factor_values` that fall in it;
#                   if FALSE, label it "min-max"
#   n_bottom      - number of leading levels kept unmerged
#   n_top         - number of trailing levels kept unmerged
#   factor_values - observation-level values (required when numeric = TRUE)
#
# Returns an ordered factor over the recoded labels, preserving the original
# level order. If there are too few levels to form a middle group, the input
# levels are returned unchanged as a (non-ordered) factor.
recode_bottom_top_middle <- function(levels, numeric = FALSE, n_bottom = 3, n_top = 3, factor_values = NULL) {
  # Check if input is valid
  if (is.null(levels) || length(levels) < 1) {
    stop("Input levels must be a non-empty numeric vector.")
  }
  
  # Ensure levels are numeric
  levels <- as.numeric(levels)
  
  # If numeric = TRUE, factor_values must be provided
  if (numeric && (is.null(factor_values) || length(factor_values) < 1)) {
    stop("For numeric = TRUE, a valid 'factor_values' vector must be provided.")
  }
  
  # Not enough levels to carve out bottom / middle / top: return unchanged
  if (length(levels) <= (n_bottom + n_top + 1)) {
    return(factor(levels, levels = as.character(levels)))
  }
  
  # Bottom and top ranges follow the original ordering of `levels`
  bottom <- levels[seq_len(n_bottom)]                         # first n_bottom values
  top <- levels[(length(levels) - n_top + 1):length(levels)]  # last n_top values
  middle <- levels[!(levels %in% c(bottom, top))]             # remaining middle
  
  # Build the label for the merged middle group
  if (length(middle) > 1) {
    if (numeric) {
      # Mean of the observations that fall in the middle group
      middle_values <- factor_values[factor_values %in% middle]
      middle_label <- as.character(round(mean(as.numeric(as.character(middle_values)), na.rm = TRUE), 3))
    } else {
      middle_label <- paste0(min(middle), "-", max(middle))  # range as label
    }
  } else if (length(middle) == 1) {
    middle_label <- as.character(middle)  # single value keeps its own label
  } else {
    middle_label <- NULL
  }
  
  # Recode each level: middle levels share `middle_label`, others keep their value.
  # vapply (instead of sapply) guarantees a character vector of the same length.
  recoded <- vapply(levels, function(x) {
    if (x %in% middle) middle_label else as.character(x)
  }, character(1))
  
  # Ordered factor whose level order follows the original order of `levels`
  factor(recoded, levels = unique(recoded), ordered = TRUE)
}

# Function to extract paths for all terminal nodes from any tree
# Given a partykit `party` tree, returns the sequence of split conditions
# leading to each terminal node, in four forms:
#   full        - list of condition vectors, one per terminal node
#   full_concat - same, each vector collapsed with " AND "
#   simp        - simplified via simplify_path() (tightest bound per variable)
#   simp_concat - simplified and collapsed with " AND "
# List names are the original terminal-node IDs.
# NOTE(review): expects `tree` to already be a party object (rpart models must
# be converted with as.party() by the caller); also requires partykit to be
# attached for nodeids().
extract_paths_for_all_terminals <- function(tree) {
  # Convert the tree to a party object (if not already)
  party_tree <- tree #as.party(tree)
  
  # Get terminal node IDs
  terminal_nodes <- nodeids(party_tree, terminal = TRUE)
  
  # Build the root-to-leaf condition path for the index-th terminal node
  extract_path_for_terminal <- function(party_tree, terminal_node_index) {
    # Get the root node of the tree
    root_node <- party_tree$node
    
    # Retrieve variable names from the tree's data
    var_names <- names(party_tree$data)
    
    # Get terminal node IDs
    terminal_nodes <- partykit::nodeids(party_tree, terminal = TRUE)
    
    # Ensure the requested terminal node index is valid
    if (terminal_node_index > length(terminal_nodes) || terminal_node_index < 1) {
      stop("Terminal node index exceeds the number of terminal nodes.")
    }
    
    # Get the node ID corresponding to the terminal node index
    terminal_node_id <- terminal_nodes[terminal_node_index]
    
    # Determine whether the first child holds the "<" branch by parsing the
    # node's printed representation.
    # NOTE(review): this relies on partykit's print layout (operator appears
    # on the second printed line) -- fragile across partykit versions.
    extract_operator_from_second_line <- function(node) {
      # Capture the printed output as a character vector
      node_text <- capture.output(print(node))
      
      # Ensure there is a second line in the printed output
      if (length(node_text) < 2) {
        stop("The printed output does not contain enough lines to extract the operator.")
      }
      
      # Extract the second line
      second_line <- node_text[2]
      
      # Check for the presence of "<" or ">="
      # (grepl("<") is tested first, so a line containing "<" wins even if ">=" also appears)
      if (grepl("<", second_line)) {
        return("<")
      } else if (grepl(">=", second_line)) {
        return(">=")
      } else {
        stop("No valid operator (< or >=) found in the second line.")
      }
    }
    
    # Recursive function to traverse the tree and build the path
    traverse_tree <- function(node, path = NULL) {
      # If this is the designated terminal node, return the path
      if (node$id == terminal_node_id) {
        return(path)
      }
      
      # If there are no children (terminal node) and it's not the designated one, return NULL
      if (is.null(node$kids)) {
        return(NULL)
      }
      
      # Extract split information
      split_variable <- var_names[node$split$varid]
      split_value <- node$split$breaks
      
      is_categ <- FALSE
      
      # Check which child corresponds to the "left" (< split_value) branch
      # NOTE(review): print(node) here prints the node to the console as a side
      # effect and returns it; the helper then prints it again via capture.output.
      left_child_is_kid1 <- grepl('<', extract_operator_from_second_line(print(node)))
      
      if (left_child_is_kid1) {
        # Traverse the left child
        if (terminal_node_id %in% nodeids(node$kids[[1]])) {
          left_path <- c(path, paste0(split_variable, " < ", split_value))
          return(traverse_tree(node$kids[[1]], left_path))
        }
        # Traverse the right child
        if (terminal_node_id %in% nodeids(node$kids[[2]])) {
          right_path <- c(path, paste0(split_variable, " >= ", split_value))
          return(traverse_tree(node$kids[[2]], right_path))
        }
      } else {
        # Traverse the right child
        if (terminal_node_id %in% nodeids(node$kids[[1]])) {
          right_path <- c(path, paste0(split_variable, " >= ", split_value))
          return(traverse_tree(node$kids[[1]], right_path))
        }
        # Traverse the left child
        if (terminal_node_id %in% nodeids(node$kids[[2]])) {
          left_path <- c(path, paste0(split_variable, " < ", split_value))
          return(traverse_tree(node$kids[[2]], left_path))
        }
      }
      
      return(NULL)  # No path found
    }
    
    # Start traversing from the root node
    path <- traverse_tree(root_node)
    
    # Return the path as a list of conditions
    return(path)
  }
  
  # Loop through each terminal node and extract the path
  paths <- lapply(seq_along(terminal_nodes), function(index) {
    extract_path_for_terminal(party_tree, index)
  })
  
  # Combine the paths with their corresponding terminal node IDs
  names(paths) <- terminal_nodes
  # Simplified variants: keep only the tightest bound per variable, rounded
  simplified_paths = setNames(nm = names(paths)) %>% lapply(function(name_p){
    simplify_path(path_conditions =paths[[name_p]], round_digits = 2)
  }) 
  
  return(list('full' = paths,
              'full_concat' = lapply(paths, function(p) paste(p, collapse = ' AND ')),
              'simp' = simplified_paths,
              'simp_concat' = lapply(simplified_paths, function(p) paste(p, collapse = ' AND '))
  )
  )
}


# Collapse a leaf's raw split conditions into one compact condition per
# variable: keep only the tightest upper bound ("<"-type) and the tightest
# lower bound (">"-type) for each variable, optionally rounding cut values.
# A two-sided constraint renders as "a < var < b"; contradictory bounds stop
# with an error. Returns a named character vector (one entry per variable).
simplify_path <- function(path_conditions, round_digits = NULL) {
  # Break one "var op value" string into its three components
  parse_condition <- function(cond) {
    pieces <- regmatches(cond, regexec("^(\\w+)\\s*([<>]=?)\\s*(.*)$", cond))[[1]]
    list(var = pieces[2], op = pieces[3], val = as.numeric(pieces[4]))
  }
  parsed <- lapply(path_conditions, parse_condition)

  # Bundle every condition that constrains the same variable
  by_variable <- split(parsed, vapply(parsed, function(p) p$var, character(1)))

  # Reduce one variable's conditions to its tightest bounds and render them
  collapse_bounds <- function(conds) {
    op_of <- vapply(conds, function(p) p$op, character(1))
    uppers <- conds[op_of %in% c("<", "<=")]
    lowers <- conds[op_of %in% c(">", ">=")]

    # Tightest upper bound = smallest "<" value
    upper <- NULL
    if (length(uppers) > 0) {
      upper <- uppers[[which.min(vapply(uppers, function(p) p$val, numeric(1)))]]
    }
    # Tightest lower bound = largest ">" value
    lower <- NULL
    if (length(lowers) > 0) {
      lower <- lowers[[which.max(vapply(lowers, function(p) p$val, numeric(1)))]]
    }

    # Optional rounding of the retained cut points
    if (!is.null(round_digits)) {
      if (!is.null(upper)) upper$val <- round(upper$val, round_digits)
      if (!is.null(lower)) lower$val <- round(lower$val, round_digits)
    }

    # Render: two-sided range, single bound, or nothing
    if (!is.null(upper) && !is.null(lower)) {
      if (lower$val >= upper$val) {
        stop("Conflicting conditions for variable ", upper$var)
      }
      return(paste0(lower$val, " < ", lower$var, " < ", upper$val))
    }
    if (!is.null(upper)) {
      return(paste0(upper$var, " ", upper$op, " ", upper$val))
    }
    if (!is.null(lower)) {
      return(paste0(lower$var, " ", lower$op, " ", lower$val))
    }
    NULL
  }

  unlist(lapply(by_variable, collapse_bounds))
}

5.1 Partitioning policyholders following proxy_vulnerability

We apply an optimal partitioning algorithm, evtree from Grubinger, Zeileis, and Pfeiffer (2014), to policyholders based on proxy vulnerability in the three scenarios of the example. We use \((X_1, X_2)\) as the feature space for partitioning and impose strong regularization to limit the number of groups.

Figure fig-ex_partitioning_clusters presents the results for the three scenarios. The top row shows estimated proxy vulnerability, with colors indicating the groups resulting from the optimal partition of proxy vulnerability. While the left panel may not match intuition in terms of the number of groups in scenario 1, the predicted values of proxy vulnerability based on the evtree align with expectations: darker red indicates individuals most vulnerable to proxy effects. The bottom row of Figure fig-ex_partitioning_clusters depicts the partition in the \((x_1, x_2)\) domain. The structure aligns with the example design: high proxy vulnerability for individuals with \(x_2 = 4\) and large \(x_1\), and important variation in proxy vulnerability across \(x_2\).

Training the evtrees per scenario
# Helper script defining process_populations(), which trains (or loads cached)
# evtree + rpart models per scenario and response, saving them under output_dir.
source("___train_evtree_scenario.R")

# Subsample the training sets to 10% to keep evtree training tractable;
# validation/other sets are kept whole.
# NOTE(review): sample_frac() is random and no seed is set here, so results
# depend on session RNG state unless the models are loaded from cache.
pregroup_pop_stats_small <- setNames(nm = names(pregroup_pop_stats)) %>% lapply(function(pop_name){
  setNames(nm = names(pregroup_pop_stats[[pop_name]])) %>% lapply(function(the_set){
    the_frac <- ifelse(the_set == 'train', 0.1 , 1)
    pregroup_pop_stats[[pop_name]][[the_set]] %>% 
      sample_frac(the_frac)
  })
})

# Define hyperparameter grid
# minbucket is expressed as a share (3% / 5%) of the subsampled Scenario 1 training set
param_grid <- expand.grid(
  minbucket = c(0.03, 0.05) * nrow(pregroup_pop_stats_small$Scenario1$train), 
  maxdepth = c(3, 4),
  alpha = c(1, 2),
  ntrees = 25,
  stringsAsFactors = FALSE
)


output_dir <- "evtree" # Directory to save models
response_vars <- c("proxy_vuln", 'comm_load') # List of response variables
  
# Call process_populations with actual inputs
# (per the log below, existing cached models are loaded rather than refit)
my_trees <- process_populations(preds_pop_stats = pregroup_pop_stats_small,
                                response_vars = response_vars,
                                param_grid = param_grid,
                                output_dir)
INFO [2025-10-06 14:35:09] Processing response variable: proxy_vuln
INFO [2025-10-06 14:35:09] Processing population: Scenario1 for response: proxy_vuln
INFO [2025-10-06 14:35:09] Model for population Scenario1 and response proxy_vuln already exists. Loading...
INFO [2025-10-06 14:35:10] Processing population: Scenario2 for response: proxy_vuln
INFO [2025-10-06 14:35:10] Model for population Scenario2 and response proxy_vuln already exists. Loading...
INFO [2025-10-06 14:35:10] Processing population: Scenario3 for response: proxy_vuln
INFO [2025-10-06 14:35:10] Model for population Scenario3 and response proxy_vuln already exists. Loading...
INFO [2025-10-06 14:35:10] Processing response variable: comm_load
INFO [2025-10-06 14:35:10] Processing population: Scenario1 for response: comm_load
INFO [2025-10-06 14:35:10] Model for population Scenario1 and response comm_load already exists. Loading...
INFO [2025-10-06 14:35:12] Processing population: Scenario2 for response: comm_load
INFO [2025-10-06 14:35:12] Model for population Scenario2 and response comm_load already exists. Loading...
INFO [2025-10-06 14:35:12] Processing population: Scenario3 for response: comm_load
INFO [2025-10-06 14:35:12] Model for population Scenario3 and response comm_load already exists. Loading...
Figure 5.1: Estimated propensity in terms of \(x_1\) and \(x_2\) for simulations
Code for the visualisation of the evtrees
library(rpart)
# BUG FIX: `library(ggparty, partykit)` passed partykit as library()'s second
# positional argument (`help`), so partykit was never explicitly attached and
# only came in as a ggparty dependency. Attach each package with its own call.
library(partykit)
library(ggparty)

# For each algorithm (evtree, rpart) and each scenario, draw the fitted
# proxy-vulnerability tree with ggparty and save one 3-row figure per
# algorithm. `temp_tree` only captures ggsave()'s return values and is
# removed right after the loop.
temp_tree <- c('evtree', 'rpart') %>% lapply(function(the_algo){
  names(pregroup_grid_stats) %>% lapply(function(pop_name){
pop_id <- which(names(pregroup_grid_stats) == pop_name)

# Compute sequential terminal node IDs
# (rpart fits must be converted to partykit's party format first)
party_tree <- my_trees$proxy_vuln[[pop_name]][[paste0('best_', the_algo)]]$model
if (the_algo == 'rpart'){
party_tree <- partykit::as.party(party_tree)
}

terminal_ids <- nodeids(party_tree, terminal = TRUE)  # Original terminal node IDs
sequential_ids <- seq_along(terminal_ids)  # Create sequential IDs
id_mapping <- data.frame(terminal_id = terminal_ids, sequential_id = sequential_ids)


## Compute average prediction per terminal node
# Extract predictions and terminal node IDs
# (fitted() on a party object returns the `(fitted)` node id and `(response)` columns)
predictions <- fitted(party_tree)
avg_prediction <- aggregate(`(response)` ~ `(fitted)`,
                            data = predictions,
                            FUN = mean)

# Build the tree plot: edges with split labels, inner-node labels, terminal
# labels (renumbered id + mean prediction), and a per-leaf boxplot of resp.
# NOTE(review): this assignment is the last expression of the inner function,
# so lapply collects the plot objects.
tree_plot <- ggparty(party_tree) +
  geom_edge() +
  geom_edge_label(mapping = aes(label = !!sym("breaks_label")),
                  size = 3) +
  # Inner nodes: split variable and node size
  geom_node_label(
    line_list = list(
      aes(label = splitvar),
      aes(label = paste("N =", nodesize))
    ),
    line_gpar = list(
      list(size = 10),
      list(size = 8)
    ),
    ids = "inner",
  ) +
  # Terminal nodes: sequential node number, size, and rounded mean prediction
  geom_node_label(
    line_list = list(
      aes(label = paste0("Node ",
                         match(id, id_mapping$terminal_id),
                         ", N = ",
                         nodesize)),
      aes(label = paste0("Avg Pred. = ",
                                     round(avg_prediction$`(response)`[match(id, avg_prediction$`(fitted)`)], 2)))
    ),
    line_gpar = list(
      list(size = 8),
      list(size = 10)
    ),
    ids = "terminal", nudge_y = -0.45, nudge_x = 0.01,
    label.size = 0.15,
    size = 3
  ) +
  # Per-leaf distribution of the response, colored/filled by the leaf median
  geom_node_plot(
    gglist = list(
      geom_boxplot(aes(x = "", y = resp,
                       color = ..middle..,
                       fill = ..middle..),  # Color by median
                   outlier.color = "black"
                   , alpha = 0.7
                   ),
      theme_minimal(),
      scale_fill_gradient2(
        low = "#D7CC39", mid = "grey75", high = "#CAA8F5",
        midpoint = 0, name = "Median Value"
      ),
      scale_color_gradient2(
        low = colorspace::darken("#D7CC39", 0.3), mid = colorspace::darken("grey75", 0.3),
        high = colorspace::darken("#CAA8F5", 0.3),
      midpoint = 0, name = "Median Value"
      ),
      xlab(""), ylab(latex2exp::TeX("$\\widehat{\\Delta}_{proxy}(X_1, X_2)$")),
      scale_y_continuous(labels = scales::dollar),
      theme(axis.text.x = element_blank(),
            axis.title.y = element_text(margin = margin(l = -10)),
            axis.title.x = element_text(margin = margin(r = 20)))
    ),
    shared_axis_labels = TRUE
  ) +
  ggtitle(latex2exp::TeX(paste0('Partition of proxy vulnerable individuals for scenario ', pop_id))) +
  theme(
    plot.title = element_text(size = 16, face = "bold", hjust = 0.5)
  )
  
# Stack the three scenario plots and write the figure to disk
}) %>% ggpubr::ggarrange(plotlist = .,
                           nrow = 3,
                           widths = 15, heights = 1,
                           common.legend = T,
                           legend = 'right') %>% 
ggsave(filename = paste0("figs/graph_trees_", the_algo,"_proxy.png"),
       plot = .,
       height = 16,
       width = 12,
       units = "in",
       device = "png", dpi = 500)
})
rm(temp_tree)

5.1.1 Saving important quantities

Saving dictionaries of partition rules and predictions
# For every response (proxy_vuln, comm_load) and scenario, build a "dictionary"
# per tree type (evtree, rpart, 8-leaf rpart): the mapping from original
# terminal-node IDs to new sequential IDs ranked by decreasing predicted value,
# together with the model itself and the human-readable split path of each leaf.
dictionnary_leaves_trees <- setNames(nm = names(my_trees)) %>%  lapply(function(resp_tree){
  setNames(nm = names(my_trees[[resp_tree]])) %>% lapply(function(pop_name){
      # Models were trained on uppercase feature names (X1, X2, D)
      temp_to_pred <- pregroup_grid_stats[[pop_name]]
      names(temp_to_pred) <- toupper(names(temp_to_pred))
      
      model_ev <- my_trees[[resp_tree]][[pop_name]]$best_evtree$model
      model_rpart <- my_trees[[resp_tree]][[pop_name]]$best_rpart$model 
      # Extra rpart version pruned to exactly 8 leaves (cptable row with nsplit + 1 == 8).
      # NOTE(review): which() returns integer(0) if no cptable row has exactly
      # 7 splits, which would make prune() fail -- confirm such a row always exists.
      model_rpart_prune <- prune(model_rpart,
                        cp = model_rpart$cptable[which(model_rpart$cptable[, "nsplit"] + 1 == 8), "CP"]) %>% as.party()
      model_rpart <- model_rpart %>% as.party()
      
      # Node dictionary: original node ID + rounded leaf prediction, ranked so
      # that node_new = 1 is the highest-prediction leaf
      to_return_evtree <- data.frame('node_or' = predict(model_ev, newdata = temp_to_pred,
                                  type = 'node') %>% unname,
                                  'pred' = predict(model_ev, newdata = temp_to_pred,
                                                   type = 'response') %>% unname %>%
                                    round(3)) %>% distinct() %>% arrange(-pred) %>%
        mutate('node_new' = 1:n())
      
      to_return_rpart <- data.frame('node_or' = predict(model_rpart, newdata = temp_to_pred,
                                  type = 'node') %>% unname,
                                  'pred' = predict(model_rpart, newdata = temp_to_pred, type = 'response') %>% unname %>%
                                    round(3)) %>% distinct() %>% arrange(-pred) %>%
        mutate('node_new' = 1:n())
      
      to_return_rpart_prune <- data.frame('node_or' = predict(model_rpart_prune, newdata = temp_to_pred,
                                  type = 'node') %>% unname,
                                  'pred' = predict(model_rpart_prune, newdata = temp_to_pred,
                                                   type = 'response') %>% unname %>%
                                    round(3)) %>% distinct() %>% arrange(-pred) %>%
        mutate('node_new' = 1:n())
      
      # Split-rule descriptions (full and simplified) for each terminal node
      paths_ev <- extract_paths_for_all_terminals(tree = model_ev)
      paths_rpart <- extract_paths_for_all_terminals(tree = model_rpart)
      paths_rpart_prune <- extract_paths_for_all_terminals(tree = model_rpart_prune)
      
       list('evtree' = list('dict' = to_return_evtree,
                            'model' = model_ev,
                            'paths' = paths_ev),
           'rpart' = list('dict' = to_return_rpart,
                          'model' = model_rpart,
                          'paths' = paths_rpart),
           'rpart_prune' = list('dict' = to_return_rpart_prune,
                                'model' = model_rpart_prune,
                                'paths' = paths_rpart_prune))
    })
})

# Persist for reuse by later sections
saveRDS(dictionnary_leaves_trees, 'evtree/dictionnary_leaves_trees.rds')

### Applying partition to the data
# Cache file paths: partition assignments for the evaluation grid and for the
# simulated populations
group_grid_path = 'preds/group_grid_stats.json'
group_pop_path = 'preds/group_pop_stats.json'


# Check and load or compute group_grid_stats
if (file.exists(group_grid_path)) {
  temp_grid_stats <- fromJSON(group_grid_path) 
  
  
  # JSON round-trips drop factor information: rebuild the group columns as
  # factors with levels in decreasing numeric order (highest prediction first)
  group_grid_stats <- setNames(nm = names(temp_grid_stats)) |> lapply(function(pop_name){
    temp_grid_stats[[pop_name]] |> 
      mutate(proxy_g_evtree = proxy_g_evtree %>% factor(., levels = sort(unique(as.numeric(proxy_g_evtree)), decreasing = T)),
           proxy_g_rpart = proxy_g_rpart %>% factor(., levels = sort(unique(as.numeric(proxy_g_rpart)), decreasing = T)),
           cload_g_evtree = cload_g_evtree %>% factor(., levels = sort(unique(as.numeric(cload_g_evtree)), decreasing = T)),
           cload_g_rpart = cload_g_rpart %>% factor(., levels = sort(unique(as.numeric(cload_g_rpart)), decreasing = T)))
  })
  
  rm(temp_grid_stats)
    
} else {
  # No cache: assign each grid point to a group, defined as the (rounded) leaf
  # prediction of the best evtree/rpart for each response variable
  group_grid_stats <- setNames(nm = names(pregroup_grid_stats)) %>% lapply(function(pop_name){
    temp_to_pred <-   pregroup_grid_stats[[pop_name]]
    # Models were trained on uppercase feature names (X1, X2, D)
    names(temp_to_pred) <- toupper(names(temp_to_pred))
    
    pred_proxy_g_evtree <- predict(my_trees$proxy_vuln[[pop_name]]$best_evtree$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    pred_proxy_g_rpart <- predict(my_trees$proxy_vuln[[pop_name]]$best_rpart$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    pred_cload_g_evtree <- predict(my_trees$comm_load[[pop_name]]$best_evtree$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    pred_cload_g_rpart <- predict(my_trees$comm_load[[pop_name]]$best_rpart$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    
    # Factor levels in decreasing order so level 1 = most vulnerable group
    data.frame(pregroup_grid_stats[[pop_name]],
                 proxy_g_evtree = pred_proxy_g_evtree %>% factor(., levels = sort(unique(pred_proxy_g_evtree), decreasing = T)),
                 proxy_g_rpart = pred_proxy_g_rpart %>% factor(., levels = sort(unique(pred_proxy_g_rpart), decreasing = T)),
                 cload_g_evtree = pred_cload_g_evtree %>% factor(., levels = sort(unique(pred_cload_g_evtree), decreasing = T)),
                 cload_g_rpart = pred_cload_g_rpart %>% factor(., levels = sort(unique(pred_cload_g_rpart), decreasing = T))
               )
  })
  # Write the cache for subsequent runs
  toJSON(group_grid_stats, pretty = TRUE, auto_unbox = TRUE) %>% 
    write(group_grid_path)
}

# Check and load or compute group_pop_stats
# Same logic as group_grid_stats above, but applied to each simulated
# population's train/valid/... sets instead of the evaluation grid.
if (file.exists(group_pop_path)) {
   temp_pop_stats <- fromJSON(group_pop_path) 
  
  
  # Rebuild group factors (JSON does not preserve factor levels), decreasing order
  group_pop_stats <- setNames(nm = names(temp_pop_stats)) |> lapply(function(pop_name){
    setNames(nm = names(temp_pop_stats[[pop_name]])) |> lapply(function(the_set){
      temp_pop_stats[[pop_name]][[the_set]] |> 
        mutate(proxy_g_evtree = proxy_g_evtree %>% factor(., levels = sort(unique(as.numeric(proxy_g_evtree)), decreasing = T)),
           proxy_g_rpart = proxy_g_rpart %>% factor(., levels = sort(unique(as.numeric(proxy_g_rpart)), decreasing = T)),
           cload_g_evtree = cload_g_evtree %>% factor(., levels = sort(unique(as.numeric(cload_g_evtree)), decreasing = T)),
           cload_g_rpart = cload_g_rpart %>% factor(., levels = sort(unique(as.numeric(cload_g_rpart)), decreasing = T)))
    }) 
  })
  
  rm(temp_pop_stats)
  
} else {
  # No cache: assign every policyholder to a group via the best trees.
  # NOTE(review): unlike the grid version, column names are NOT uppercased here,
  # which suggests the population data already carries uppercase X1/X2/D -- confirm.
  group_pop_stats <- setNames(nm = names(pregroup_pop_stats)) %>% lapply(function(pop_name){
    setNames(nm = names(pregroup_pop_stats[[pop_name]])) %>% lapply(function(set){
      
      temp_to_pred <- pregroup_pop_stats[[pop_name]][[set]]
      
    pred_proxy_g_evtree <- predict(my_trees$proxy_vuln[[pop_name]]$best_evtree$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    pred_proxy_g_rpart <- predict(my_trees$proxy_vuln[[pop_name]]$best_rpart$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    pred_cload_g_evtree <- predict(my_trees$comm_load[[pop_name]]$best_evtree$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    pred_cload_g_rpart <- predict(my_trees$comm_load[[pop_name]]$best_rpart$model,
                                          newdata = temp_to_pred) %>% round(3) %>% unname
    
    # Factor levels in decreasing order so level 1 = most vulnerable group
    data.frame(pregroup_pop_stats[[pop_name]][[set]],
                 proxy_g_evtree = pred_proxy_g_evtree %>% factor(., levels = sort(unique(pred_proxy_g_evtree), decreasing = T)),
                 proxy_g_rpart = pred_proxy_g_rpart %>% factor(., levels = sort(unique(pred_proxy_g_rpart), decreasing = T)),
                cload_g_evtree = pred_cload_g_evtree %>% factor(., levels = sort(unique(pred_cload_g_evtree), decreasing = T)),
                 cload_g_rpart = pred_cload_g_rpart %>% factor(., levels = sort(unique(pred_cload_g_rpart), decreasing = T))
                )
    })  
})
  # Write the cache for subsequent runs
  toJSON(group_pop_stats, pretty = TRUE, auto_unbox = TRUE) %>% 
    write(group_pop_path)
}

5.1.2 Visualizing the partition

Graph for the illustration of the partitioning
# Figure 5.2 build: for each scenario, draw (top) the estimated proxy
# vulnerability curves on the grid, colored by evtree group, and (bottom) the
# partition of the validation population in the (x1, x2) plane. The first and
# last `n_bottom`/`n_top` groups keep their own label; the middle groups are
# merged via recode_bottom_top_middle().
n_bottom <- 5
n_top <- 5

setNames(nm = names(group_pop_stats)) %>%  lapply(function(pop_name){
  ## the colors
pop_id <- which(names(group_pop_stats) == pop_name)
    
# Grid slice shown in the top panel: restricted x1 range, protected class d == 1
local_to_g <- group_grid_stats[[pop_name]] %>% 
filter(x1 <= 8, x1 >= -5, d == 1) 

# Axis labels/scales only on the first (left-most) scenario column
if(pop_name == head(names(group_grid_stats), 1)){
  the_y_scale_top <- scale_y_continuous(labels = scales::dollar, breaks = c(-5, 0, 5, 10), limits = c(-6, 14))
  the_y_label_top <- latex2exp::TeX("$\\Delta_{proxy}(x_1, x_2)$")
  the_y_scale <- scale_y_discrete()
  the_y_label <- latex2exp::TeX('$x_2$')
} else {
  the_y_scale_top <- scale_y_continuous(labels = NULL, breaks = c(-5, 0, 5, 10), limits = c(-6, 14))
  the_y_label_top <- NULL
    the_y_scale <-scale_y_discrete(labels = NULL)
  the_y_label <- NULL
}

# Population shown in the bottom panel: validation set only
local_pop_g <- group_pop_stats[[pop_name]]$valid

# Merge middle evtree groups into a single level for a readable legend
local_to_g$proxy_g_evtree_g <- local_to_g$proxy_g_evtree
local_pop_g$proxy_g_evtree_g <- local_pop_g$proxy_g_evtree
levels(local_to_g$proxy_g_evtree_g) <- recode_bottom_top_middle(levels = levels(local_to_g$proxy_g_evtree), 
                                                          numeric = TRUE, 
                                                          n_bottom = n_bottom, 
                                                          n_top = n_top, 
                                                          factor_values = local_to_g$proxy_g_evtree)
levels(local_pop_g$proxy_g_evtree_g) <- recode_bottom_top_middle(levels = levels(local_pop_g$proxy_g_evtree), 
                                                          numeric = TRUE, 
                                                          n_bottom = n_bottom, 
                                                          n_top = n_top, 
                                                          factor_values = local_pop_g$proxy_g_evtree)

# Force the grid's merged-middle label to match the population's so that the
# two panels share one legend
levels(local_to_g$proxy_g_evtree_g)[n_top + 1] <- levels(local_pop_g$proxy_g_evtree_g)[n_top + 1]

# Top panel: theoretical curves (black, by x2) vs estimated values colored by group
g_proxy <- local_to_g %>% 
  mutate(code = paste0(x2, '_', as.numeric(proxy_g_evtree_g))) %>% 
  ggplot(aes(x = x1, y = proxy_vuln,
             group = factor(code),
             color = factor(proxy_g_evtree_g))) + 
  geom_line(aes(x = x1, y = proxy_vuln_t,
                lty = factor(x2), group = factor(x2)),
                color = 'black', size = 0.8) +
  geom_line(size = 3, alpha = 0.78, lineend = "round", linejoin = "round") + 
  theme_classic() + 
  labs(x = latex2exp::TeX('$x_1$'),
       y = the_y_label_top,
       title = paste0('Scenario ', pop_id)) + 
  scale_color_manual(values = RColorBrewer::brewer.pal(n_bottom + n_top + 1, 'Spectral')  %>% colorspace::darken(0.05),
                     name = latex2exp::TeX('$\\widehat{\\Delta}^{ev}_{proxy}(\\textbf{x})$'),
                     labels = levels(local_to_g$proxy_g_evtree_g) %>% as.numeric %>% round(2)) + 
  scale_linetype_manual(values = c('12',  '21', '32', 'solid'), name = latex2exp::TeX('$x_2$')) +
  the_y_scale_top + 
  # geom_abline(slope = 0, intercept = 0, lty = '34', color= 'black', size= 0.7, alpha = 0.2) + 
  scale_x_continuous(labels = NULL, breaks = c(-3:3)*3 + 1) + # see above 
  guides(
    linetype = guide_legend(order = 1), # x2 legend on top
    color = guide_legend(order = 2)    # k legend below x2
  )

# Bottom panel: validation policyholders in the (x1, x2) plane, jittered
# vertically so the discrete x2 rows are readable
g_population <- local_pop_g %>%  
  ggplot(aes(y = factor(X2), x = X1,
             color = factor(proxy_g_evtree_g),
             fill = factor(proxy_g_evtree_g))) + 
  geom_jitter(#position = position_identity(), 
              width = 0, height = 0.4, alpha = 0.2) + 
  scale_color_manual(values = RColorBrewer::brewer.pal(n_bottom + n_top + 1, 'Spectral') %>% colorspace::darken(0.05),
                     name = latex2exp::TeX('$\\widehat{\\Delta}^{ev}_{proxy}(\\textbf{x})$'),
                     labels = levels(local_to_g$proxy_g_evtree) %>% as.numeric %>% round(1)) + 
  scale_fill_brewer(palette = 'Spectral', name = latex2exp::TeX('$k$')) + 
  theme_classic() + 
  the_y_scale +
  scale_x_continuous(breaks = c(-3:3)*3 + 1, limits = c(-5, 8)) + 
  labs(x = latex2exp::TeX("$x_1$") ,
       y = the_y_label) + 
  theme( axis.title.y = element_text(
      margin = margin(t = 50), # Add padding
    )) 
# Stack the two panels for this scenario
ggpubr::ggarrange(g_proxy, g_population, 
                  nrow = 2, common.legend = T,
                  legend = 'right',
                  heights = c(4, 3),
                  align = "v") 
}) %>%  
  # One column per scenario, then save the full figure
  ggpubr::ggarrange(plotlist = .,
                    ncol = 3,
                    widths = c(6, 5, 5)) %>% 
  ggsave(filename = "figs/graph_proxy_clusters_and_pop_scenario.png",
       plot = .,
       height = 6.25,
       width = 11.50,
       units = "in",
       device = "png", dpi = 500)
Figure 5.2: Partition of proxy vulnerability across the three scenarios (columns) in the example. The top row compares theoretical proxy vulnerability (black) with estimated values from the lightgbm model, the latter colored by the group formed by the tree. The bottom row shows the partition in the \((x_1, x_2)\) domain, with noise added to \(x_2\) for clarity.

5.2 Experiment on partitioning proxy vulnerability for Scenario 1

To support the partitioning methodology proposed in section 6 of the main paper, we present a simulation study to identify best practices and gain insights into optimal implementation.

Following the set of equations of Scenario 1 in sec-simul-dataset, we simulate \(M = 100\) sets of \(N = 3000\) samples split into \((N_{\text{train}}, N_{\text{valid}}, N_{\text{test}}) = (2000, 500, 500)\) for train, validation, and test. Our aim is to assess the capacity of the methodology to recover distinct proxy-vulnerable subpopulations, precisely identify the most at-risk groups, and predict the proxy vulnerability accurately. We compute the BIC (under Gaussian assumption) of the estimated proxy vulnerability as compared with the test theoretical proxy vulnerability, and the accuracy in the partitioning as compared with the true proxy-vulnerable groups: the eight subpopulations formed by the crossing of \(\{X_1 \leq 1, X_1 > 1\}\) and \(X_2 \in \{1, 2, 3, 4\}\).

On each sample set, we estimated unaware and aware premiums using the methodology described in sec-training and we computed proxy vulnerability \(\widehat{\Delta}_{\text{proxy}}(x_1, x_2)\) via Eq. 1 of the main paper. We partitioned the feature space using \((X_1, X_2, D)\). Models used rpart (locally optimal) and evtree (globally optimal) regression trees, with hyperparameters tuned via validation BIC (Gaussian assumption): minimum leaf size proportion \(w \in \{0.03, 0.05\}\), tree depth \(d \in \{3, 4\}\), and complexity parameter \(\alpha \in \{1, 2\}\). For rpart, we pruned a deep tree to minimize validation error. We compared performance under fixed (\(k = 8\)) and optimized leaf counts, retaining the best model per case for each implementation (four total).

Code for the partition experiment.
# Load the cached partition-experiment results if they exist; otherwise run
# the evtree experiment script, aggregate its predictions, and cache them
# to CSV so subsequent renders skip the expensive computation.
if (file.exists("preds/proxy_sims_results.csv")) {
  proxy_sims_results <- read_csv("preds/proxy_sims_results.csv")
  cat("Results read from preds/proxy_sims_results.csv")
} else {
  source("___evtree_experiment.R")

  # Raw per-simulation prediction statistics
  preds_sims_stats <- jsonlite::fromJSON("preds/preds_sims_stats.json")

  # Aggregate predictions into one row of experiment metrics per simulation
  proxy_sims_results <- process_data_evtree_experiment(preds_sims_stats)

  # Cache the aggregated results for future renders
  write.csv(proxy_sims_results, "preds/proxy_sims_results.csv", row.names = FALSE)
  cat("Results saved to preds/proxy_sims_results.csv")
}
Results read from preds/proxy_sims_results.csv

The following table presents the results of our experiment. When the number of groups is correctly set to \(k = 8\), the \(8\times8\) accuracy and relaxed \(8\times8\) accuracy (which counts adjacent diagonals in the confusion matrix as correct) confirm that the method effectively groups individuals, with little difference between rpart and evtree. However, when \(k\) is unknown, \(k = 8\) was never selected as optimal. Validation metrics (based on estimated proxy vulnerability) favor larger \(k\), while oracle performance (based on theoretical proxy vulnerability) suggests better \(R^2\) at \(k = 8\). Whether \(k\) is known or not, partitioning identifies truly vulnerable individuals (top 12% by theoretical proxy vulnerability) with over 93% precision, fulfilling its primary objective. Since \(k\) is unknown in practice, strong regularization is essential to prevent excessive partitioning.

Code for producing the experimental result table
# Mean/SD summary of the partition experiment: one column pair per metric,
# one row overall.
#
# The previous version spelled out mean()/sd() for every column and used
# `proxy_sims_results$col` *inside* summarise() -- an anti-pattern that
# bypasses data masking (and would silently misbehave under grouping) --
# across ~130 duplicated lines.  Here a single mapping from output name to
# source column drives the computation.  The resulting one-row table has
# exactly the same column names (each metric `out` plus `out_sd`), so the
# downstream `summary_table` construction is unaffected.

# Output name -> column of proxy_sims_results ("_sd" suffix added for SDs).
# Grouped per model configuration, in the original column order.
.metric_cols <- c(
  # "known k" (k = 8), evtree
  k_8 = "num_leaf_8", minbucket_8 = "minbucket_8", maxdepth_8 = "maxdepth_8",
  alpha_8 = "alpha_8", validation_mse_8 = "validation_mse_8",
  oracle_r2_8 = "r2_oracle_8", acc_test_8 = "accuracy_test_8",
  relaxed_acc_test_8 = "relaxed_accuracy_test_8",
  top_acc_8 = "top_acc_8", bottom_acc_8 = "bottom_acc_8",
  time_8 = "time_8",  # time assumed same for any/known

  # "known k" (k = 8), rpart
  k_r8 = "num_leaf_r8", minbucket_r8 = "minbucket_r8",
  maxdepth_r8 = "maxdepth_r8", cp_8 = "cp_8",
  validation_mse_r8 = "validation_mse_r8", oracle_r2_r8 = "r2_oracle_r8",
  acc_test_r8 = "accuracy_test_r8",
  relaxed_acc_test_r8 = "relaxed_accuracy_test_r8",
  top_acc_r8 = "top_acc_r8", bottom_acc_r8 = "bottom_acc_r8",
  time_r8 = "time_r8",

  # "any k", rpart
  k_rany = "num_leaf_rany", minbucket_rany = "minbucket_rany",
  maxdepth_rany = "maxdepth_rany", cp_any = "cp_any",
  validation_mse_rany = "validation_mse_rany",
  oracle_r2_rany = "r2_oracle_rany",
  top_acc_rany = "top_acc_rany", bottom_acc_rany = "bottom_acc_rany",
  time_rany = "time_rany",

  # "any k", evtree
  k_any = "num_leaf_any", minbucket_any = "minbucket_any",
  maxdepth_any = "maxdepth_any", alpha_any = "alpha_any",
  validation_mse_any = "validation_mse_any",
  oracle_r2_any = "r2_oracle_any",
  top_acc_any = "top_acc_any", bottom_acc_any = "bottom_acc_any",
  time_any = "time_any"
)

# Detection metrics (eo / e8 / ro / r8 variants) keep their source column
# names.  outer() + as.vector() enumerates recall/prec/acc/effpct within
# each variant, matching the original column order (eo, e8, ro, r8).
.detect_cols <- as.vector(outer(
  c("recall_top_pct_", "prec_top_pct_", "acc_top_pct_", "effpct_top_pct_"),
  c("eo", "e8", "ro", "r8"),
  paste0
))
.metric_cols <- c(.metric_cols, setNames(.detect_cols, .detect_cols))

# For each mapping entry emit a one-row tibble with `out` (mean) and
# `out_sd` (sd) columns, then bind all pairs side by side.
summary_results <- purrr::imap(.metric_cols, function(col, out) {
  x <- proxy_sims_results[[col]]
  tibble::tibble(
    !!out := mean(x, na.rm = TRUE),
    !!paste0(out, "_sd") := sd(x, na.rm = TRUE)
  )
}) %>%
  dplyr::bind_cols()

# Step 3: Transform results into a cleaner format
# Build the 13-row display table: one row per metric, one Mean/SD column
# pair per model configuration (rpart/evtree x known/unknown k).  Each
# c(...) below lists 13 values in the SAME order as `Metric`; positional
# alignment is the whole contract here, so edit with care.
# NOTE(review): row 4 ("Alpha") holds `cp_*` for the rpart columns and
# `alpha_*` for the evtree columns -- each implementation's regularization
# hyperparameter shares one display row.
summary_table <- data.frame(
  # Display labels for the 13 metrics (table rows).
  Metric = c(
    "Number of Leaves (k)", "Min Split", "Max Depth", "Alpha", 
    "Validation BIC", "Oracle R²", "Accuracy Test", "Relaxed Accuracy Test",
    "Top Accuracy", "Bottom Accuracy", "Detect recall", "Detect effective %", "Time"
  ),
  # rpart with k fixed to 8: means ("_r8" suffix; detection metrics use "r8").
  Rpart_Known_k_Mean = c(
    summary_results$k_r8, summary_results$minbucket_r8, summary_results$maxdepth_r8,
    summary_results$cp_8, summary_results$validation_mse_r8, summary_results$oracle_r2_r8,
    summary_results$acc_test_r8, summary_results$relaxed_acc_test_r8,
    summary_results$top_acc_r8, summary_results$bottom_acc_r8, summary_results$recall_top_pct_r8, summary_results$effpct_top_pct_r8, summary_results$time_r8
  ),
  # rpart with k fixed to 8: standard deviations.
  Rpart_Known_k_SD = c(
    summary_results$k_r8_sd, summary_results$minbucket_r8_sd, summary_results$maxdepth_r8_sd,
    summary_results$cp_8_sd, summary_results$validation_mse_r8_sd, summary_results$oracle_r2_r8_sd,
    summary_results$acc_test_r8_sd, summary_results$relaxed_acc_test_r8_sd,
    summary_results$top_acc_r8_sd, summary_results$bottom_acc_r8_sd,  summary_results$recall_top_pct_r8_sd, summary_results$effpct_top_pct_r8_sd, summary_results$time_r8_sd
  ),
  # evtree with k fixed to 8: means ("_8" suffix; detection metrics use "e8").
  Ev_Known_k_Mean = c(
    summary_results$k_8, summary_results$minbucket_8, summary_results$maxdepth_8,
    summary_results$alpha_8, summary_results$validation_mse_8, summary_results$oracle_r2_8,
    summary_results$acc_test_8, summary_results$relaxed_acc_test_8,
    summary_results$top_acc_8, summary_results$bottom_acc_8,summary_results$recall_top_pct_e8, summary_results$effpct_top_pct_e8,  summary_results$time_8
  ),
  # evtree with k fixed to 8: standard deviations.
  Ev_Known_k_SD = c(
    summary_results$k_8_sd, summary_results$minbucket_8_sd, summary_results$maxdepth_8_sd,
    summary_results$alpha_8_sd, summary_results$validation_mse_8_sd, summary_results$oracle_r2_8_sd,
    summary_results$acc_test_8_sd, summary_results$relaxed_acc_test_8_sd,
    summary_results$top_acc_8_sd, summary_results$bottom_acc_8_sd,summary_results$recall_top_pct_e8_sd, summary_results$effpct_top_pct_e8_sd,  summary_results$time_8_sd
  ),
  # rpart with tuned k: means ("_rany" suffix; detection metrics use "ro").
   Rpart_Any_k_Mean = c(
    summary_results$k_rany, summary_results$minbucket_rany, summary_results$maxdepth_rany,
    summary_results$cp_any, summary_results$validation_mse_rany, summary_results$oracle_r2_rany,
    NA, NA,  # Accuracy metrics not available for Any k
    summary_results$top_acc_rany, summary_results$bottom_acc_rany, summary_results$recall_top_pct_ro, summary_results$effpct_top_pct_ro, summary_results$time_rany
  ),
  # rpart with tuned k: standard deviations.
  Rpart_Any_k_SD = c(
    summary_results$k_rany_sd, summary_results$minbucket_rany_sd, summary_results$maxdepth_rany_sd,
    summary_results$cp_any_sd, summary_results$validation_mse_rany_sd, summary_results$oracle_r2_rany_sd,
    NA, NA,  # Accuracy metrics not available for Any k
    summary_results$top_acc_rany_sd, summary_results$bottom_acc_rany_sd, summary_results$recall_top_pct_ro_sd, summary_results$effpct_top_pct_ro_sd, summary_results$time_rany_sd
  ),
  # evtree with tuned k: means ("_any" suffix; detection metrics use "eo").
  Ev_Any_k_Mean = c(
    summary_results$k_any, summary_results$minbucket_any, summary_results$maxdepth_any,
    summary_results$alpha_any, summary_results$validation_mse_any, summary_results$oracle_r2_any,
    NA, NA,  # Accuracy metrics not available for Any k
    summary_results$top_acc_any, summary_results$bottom_acc_any,summary_results$recall_top_pct_eo, summary_results$effpct_top_pct_eo,  summary_results$time_any
  ),
  # evtree with tuned k: standard deviations.
  Ev_Any_k_SD = c(
    summary_results$k_any_sd, summary_results$minbucket_any_sd, summary_results$maxdepth_any_sd,
    summary_results$alpha_any_sd, summary_results$validation_mse_any_sd, summary_results$oracle_r2_any_sd,
    NA, NA,  # Accuracy metrics not available for Any k
    summary_results$top_acc_any_sd, summary_results$bottom_acc_any_sd, summary_results$recall_top_pct_eo_sd, summary_results$effpct_top_pct_eo_sd,  summary_results$time_any_sd
  ),
  # Category for each of the 13 rows; factor levels set the display order
  # used by arrange(group)/group_rows() in the rendering chunk.
  group = c('Hyperparam.', 'Hyperparam.', 'Hyperparam.', 'Hyperparam.', 
            'Efficiency', 'Oracle perf.', 'Oracle perf.', 'Oracle perf.', 'Oracle perf.', 'Oracle perf.',
            'Efficiency', 'Efficiency', 'Efficiency') |> factor(levels = c('Hyperparam.', 'Efficiency', 'Oracle perf.'))
)

library(knitr)
library(kableExtra)

# Round the eight numeric Mean/SD columns (positions 2-9) to three decimals
# for display; column 1 is the metric label, column 10 the group factor.
summary_table[, 2:9] <- round(summary_table[, 2:9], 3)

# Arrange columns into display order, sort rows by metric group, and move
# the group factor to the front (it drives the row grouping below).
table_to_g <- summary_table |>
  dplyr::select(all_of(c(
    "Metric",
    "Rpart_Known_k_Mean", "Rpart_Known_k_SD",
    "Ev_Known_k_Mean", "Ev_Known_k_SD",
    "Rpart_Any_k_Mean", "Rpart_Any_k_SD",
    "Ev_Any_k_Mean", "Ev_Any_k_SD",
    "group"
  ))) |>
  arrange(group) |>
  relocate(group)

# Rows given a top border, visually separating the metric groups
# NOTE(review): indices correspond to group boundaries after arrange(group);
# confirm against rendered output if groups change.
group_lines <- c(5, 9)

# Render the table with three stacked header rows (Mean/SD, tree type,
# known/unknown k) and rows grouped by metric category.
kbl(
  dplyr::select(table_to_g, -group),
  col.names = NULL,
  caption = "Results of experimental testing. For each of the $N = 100$ samples, we obtained the best regression trees when forcing $k = 8$ or when $k$ was part of the tuned hyperparameters.",
  label = "experiment"
) |>
  add_header_above(c(
    "Metric" = 1,
    "Mean" = 1, "SD" = 1,
    "Mean" = 1, "SD" = 1,
    "Mean" = 1, "SD" = 1,
    "Mean" = 1, "SD" = 1
  )) |>
  add_header_above(c(
    " " = 1,
    "Greedy tree \n (rpart)" = 2,
    "Optimal tree \n (evtree)" = 2,
    "Greedy tree \n (rpart)" = 2,
    "Optimal tree \n (evtree)" = 2
  )) |>
  add_header_above(
    c(" " = 1, "Known k" = 4, "Unknown k" = 4),
    escape = FALSE
  ) |>
  group_rows(index = table(table_to_g$group)) |>
  row_spec(group_lines, extra_css = "border-top: 2px solid black;") |>
  kable_styling(full_width = FALSE, bootstrap_options = c("striped", "hover"))
Known k
Unknown k
Greedy tree
(rpart)
Optimal tree
(evtree)
Greedy tree
(rpart)
Optimal tree
(evtree)
Metric
Mean
SD
Mean
SD
Mean
SD
Mean
SD
Results of experimental testing. For each of the $N = 100$ samples, we obtained the best regression trees when forcing $k = 8$ or when $k$ was part of the tuned hyperparameters.
Hyperparam.
Number of Leaves (k) 8.000 0.000 8.000 0.000 14.290 1.192 12.290 1.908
Min Split 0.037 0.010 0.037 0.010 0.032 0.005 0.031 0.005
Max Depth 3.150 0.359 3.150 0.359 3.990 0.100 3.970 0.171
Alpha 0.005 0.003 1.470 0.502 0.000 0.000 1.390 0.490
Efficiency
Validation BIC 90.892 155.067 Inf NaN 39.851 165.704 7.816 158.082
Detect recall 0.951 0.118 0.959 0.104 0.907 0.153 0.926 0.136
Detect effective % 0.171 0.068 0.183 0.072 0.152 0.042 0.161 0.052
Time 0.016 0.007 23.083 7.086 0.015 0.007 43.037 14.681
Oracle perf.
Oracle R² 0.905 0.036 0.904 0.038 0.897 0.034 0.894 0.035
Accuracy Test 0.601 0.165 0.649 0.167 NA NA NA NA
Relaxed Accuracy Test 0.971 0.051 0.976 0.047 NA NA NA NA
Top Accuracy 0.941 0.076 0.932 0.079 0.957 0.064 0.951 0.067
Bottom Accuracy 0.942 0.077 0.939 0.076 0.967 0.056 0.964 0.051
Grubinger, Thomas, Achim Zeileis, and Karl-Peter Pfeiffer. 2014. “Evtree: Evolutionary Learning of Globally Optimal Classification and Regression Trees in r.” Journal of Statistical Software 61 (1): 1–29. https://doi.org/10.18637/jss.v061.i01.